{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow1 pb 推理\n",
    "\n",
    "参考：[migrating_checkpoints](https://www.tensorflow.org/guide/migrate/migrating_checkpoints)\n",
    "\n",
    "下面以模型 [resnet_v2_50](http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz) 为例展示。\n",
    "\n",
    "需要克隆项目 [models](https://github.com/tensorflow/models)，然后执行如下操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-01-06 18:32:24.578169: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-01-06 18:32:24.722915: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
      "2024-01-06 18:32:24.722945: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n",
      "2024-01-06 18:32:24.758001: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-01-06 18:32:25.501636: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
      "2024-01-06 18:32:25.501734: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
      "2024-01-06 18:32:25.501747: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "try:\n",
    "    tf1 = tf.compat.v1\n",
    "except (ImportError, AttributeError): \n",
    "    tf1 = tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "切换到 [`models/research/slim`](https://github.com/tensorflow/models/tree/master/research/slim) 目录下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tasks/models/research/slim\n"
     ]
    }
   ],
   "source": [
    "# NOTE: machine-specific absolute path; point it at the research/slim directory of your local tensorflow/models clone.\n",
    "%cd /media/pc/data/lxw/ai/tasks/models/research/slim"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "图像预处理："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "def preprocessing(\n",
    "    image,\n",
    "    use_grayscale=False,\n",
    "    central_fraction=0.875,\n",
    "    central_crop=True,\n",
    "    height=299,\n",
    "    width=299,\n",
    "    mean: tuple[float, ...] = (0.485, 0.456, 0.406),\n",
    "    std: tuple[float, ...] = (1, 1, 1)\n",
    "):\n",
    "    \"\"\"Convert, crop, resize and normalise one image with TensorFlow ops.\n",
    "\n",
    "    The steps run in this order:\n",
    "\n",
    "    1. convert to ``tf.float32`` with ``tf.image.convert_image_dtype`` unless\n",
    "       ``image.dtype`` is already ``tf.float32``;\n",
    "    2. optionally convert RGB to grayscale;\n",
    "    3. optionally keep only the central ``central_fraction`` of the image;\n",
    "    4. optionally resize to ``height`` x ``width`` (bilinear, no antialiasing,\n",
    "       aspect ratio not preserved);\n",
    "    5. subtract ``mean`` and divide by ``std``.\n",
    "\n",
    "    Args:\n",
    "        image: image to process; it must expose a ``dtype`` attribute.\n",
    "            Assumed to be a single height x width x channels image -- TODO confirm.\n",
    "        use_grayscale: when truthy, apply ``tf.image.rgb_to_grayscale``.\n",
    "        central_fraction: fraction handed to ``tf.image.central_crop``; a falsy\n",
    "            value disables the crop.\n",
    "        central_crop: when falsy, the crop is skipped.\n",
    "        height: target height; the resize is skipped if ``height`` or ``width``\n",
    "            is falsy.\n",
    "        width: target width; see ``height``.\n",
    "        mean: value(s) subtracted from the image with ``tf.subtract``.\n",
    "        std: value(s) the centred image is divided by with ``tf.divide``.\n",
    "\n",
    "    Returns:\n",
    "        The result of the final ``tf.divide``.\n",
    "    \"\"\"\n",
    "    if image.dtype != tf.float32:\n",
    "        image = tf.image.convert_image_dtype(image, dtype=tf.float32)\n",
    "\n",
    "    if use_grayscale:\n",
    "        image = tf.image.rgb_to_grayscale(image)\n",
    "\n",
    "    if central_crop and central_fraction:\n",
    "        image = tf.image.central_crop(image, central_fraction=central_fraction)\n",
    "\n",
    "    if height and width:\n",
    "        # Resize on a temporary leading batch axis, then drop that axis again.\n",
    "        batched = tf.expand_dims(image, 0)\n",
    "        resized = tf.image.resize(\n",
    "            batched,\n",
    "            [height, width],\n",
    "            method='bilinear',\n",
    "            preserve_aspect_ratio=False,\n",
    "            antialias=False,\n",
    "        )\n",
    "        image = tf.squeeze(resized, [0])\n",
    "\n",
    "    centered = tf.subtract(image, mean)\n",
    "    return tf.divide(centered, std)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将 ckpt 模型转换为 pb 模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "try:\n",
    "    tf1 = tf.compat.v1\n",
    "except (ImportError, AttributeError):\n",
    "    tf1 = tf\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "from nets import resnet_v2\n",
    "import tf_slim as slim\n",
    "import shutil\n",
    "from tvm_book.data.classification import ImageFolderDataset\n",
    "\n",
    "def remove_dir(path):\n",
    "    \"\"\"Delete the directory tree at ``path`` on a best-effort basis.\n",
    "\n",
    "    A missing ``path`` is ignored silently (nothing to clean up). Any other\n",
    "    ``OSError`` raised by ``shutil.rmtree`` is reported as a warning instead of\n",
    "    being swallowed, so a directory that could not be removed does not go\n",
    "    unnoticed. Exceptions that are not ``OSError`` (for example\n",
    "    ``KeyboardInterrupt``) are not intercepted.\n",
    "\n",
    "    Args:\n",
    "        path: directory to remove.\n",
    "\n",
    "    Returns:\n",
    "        None.\n",
    "    \"\"\"\n",
    "    import warnings\n",
    "\n",
    "    try:\n",
    "        shutil.rmtree(path)\n",
    "    except FileNotFoundError:\n",
    "        # Nothing to remove.\n",
    "        pass\n",
    "    except OSError as err:\n",
    "        # Stay best-effort, but make the failure visible.\n",
    "        warnings.warn(f\"remove_dir: could not remove {path!r}: {err}\")\n",
    "# Restrict TensorFlow's Python logger to ERROR-level messages.\n",
    "tf.get_logger().setLevel('ERROR')\n",
    "\n",
    "# Load one sample (index 1001) to use as the test image; indexing the dataset\n",
    "# yields an (image, label_id) pair.\n",
    "# NOTE: `root` is a machine-specific absolute path; presumably the ImageNet\n",
    "# (ILSVRC) validation folder -- adjust it to your own dataset location.\n",
    "root = \"/media/pc/data/lxw/home/data/datasets/ILSVRC/val\"\n",
    "valset = ImageFolderDataset(root)\n",
    "image, label_id = valset[1001]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1694: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.\n",
      "  warnings.warn('`layer.apply` is deprecated and '\n",
      "2024-01-06 18:33:42.431432: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
      "2024-01-06 18:33:42.431543: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublas.so.11'; dlerror: libcublas.so.11: cannot open shared object file: No such file or directory\n",
      "2024-01-06 18:33:42.431629: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublasLt.so.11'; dlerror: libcublasLt.so.11: cannot open shared object file: No such file or directory\n",
      "2024-01-06 18:33:42.438100: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcusparse.so.11'; dlerror: libcusparse.so.11: cannot open shared object file: No such file or directory\n",
      "2024-01-06 18:33:42.438164: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1934] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n",
      "Skipping registering GPU devices...\n",
      "2024-01-06 18:33:42.438711: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-01-06 18:33:42.474363: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled\n"
     ]
    }
   ],
   "source": [
    "import set_env # 加载 TVM\n",
    "import tvm.relay.testing.tf as tf_testing\n",
    "# Export directory for the SavedModel; any previous export is removed first.\n",
    "model_dir = 'temp/resnet_v2_50'\n",
    "remove_dir(model_dir)\n",
    "# NOTE(review): presumably fetches the archive and returns the local path of the\n",
    "# named checkpoint file -- confirm against TVM's tf testing helper.\n",
    "checkpoints_path = tf_testing.get_workload_official(\n",
    "    \"http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz\",\n",
    "    \"resnet_v2_50.ckpt\"\n",
    ")\n",
    "# Build everything in a dedicated TF1 graph: preprocessing -> ResNet -> softmax.\n",
    "with tf1.Graph().as_default() as graph:\n",
    "    processed_image = preprocessing(\n",
    "        image,\n",
    "        use_grayscale=False,\n",
    "        central_fraction=0.875,\n",
    "        central_crop=True,\n",
    "        height=299,\n",
    "        width=299,\n",
    "        mean=(0.485, 0.456, 0.406),\n",
    "        std=(1, 1, 1)\n",
    "    )\n",
    "    # Add a leading batch axis before feeding the network.\n",
    "    processed_images  = tf.expand_dims(processed_image, 0)\n",
    "    # When creating the model, use the default arg scope to configure the batch-norm parameters.\n",
    "    with slim.arg_scope(resnet_v2.resnet_arg_scope()):\n",
    "        logits, end_points = resnet_v2.resnet_v2_50(processed_images, num_classes=1001,\n",
    "                                                    global_pool=True,\n",
    "                                                    is_training=False)\n",
    "    probabilities = tf.nn.softmax(logits)\n",
    "    # Restore the model variables under 'resnet_v2_50' from the checkpoint.\n",
    "    variables = slim.get_model_variables('resnet_v2_50')\n",
    "    init_fn = slim.assign_from_checkpoint_fn(checkpoints_path, variables)\n",
    "    with tf1.Session() as sess:\n",
    "        init_fn(sess)\n",
    "        # np_probabilities, np_processed_images = sess.run([probabilities, processed_images])\n",
    "        # Reference prediction from the in-session graph; a later cell prints it.\n",
    "        np_probabilities = sess.run(probabilities)\n",
    "        # Save the session as a SavedModel with one input ('inputs') and one output ('output').\n",
    "        tf1.saved_model.simple_save(\n",
    "            sess, model_dir,\n",
    "            inputs={'inputs': processed_images},\n",
    "            outputs={'output': probabilities}\n",
    "        )"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "加载保存的模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "try:\n",
    "    tf1 = tf.compat.v1\n",
    "except (ImportError, AttributeError):\n",
    "    tf1 = tf\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "# Restrict TensorFlow's Python logger to ERROR-level messages.\n",
    "tf.get_logger().setLevel('ERROR')\n",
    "\n",
    "# Preprocessing\n",
    "# `ImageFolderDataset` and `preprocessing` are not defined in this cell; they come\n",
    "# from earlier cells, so those cells must have been run first.\n",
    "# NOTE: `root` is a machine-specific absolute path; adjust it to your dataset location.\n",
    "root = \"/media/pc/data/lxw/home/data/datasets/ILSVRC/val\"\n",
    "valset = ImageFolderDataset(root)\n",
    "image, label_id = valset[1001]\n",
    "model_dir = 'temp/resnet_v2_50'\n",
    "# remove_dir(model_dir)\n",
    "processed_image = preprocessing(\n",
    "    image,\n",
    "    use_grayscale=False,\n",
    "    central_fraction=0.875,\n",
    "    central_crop=True,\n",
    "    height=299,\n",
    "    width=299,\n",
    "    mean=(0.485, 0.456, 0.406),\n",
    "    std=(1, 1, 1)\n",
    ")\n",
    "# `.numpy()` assumes `preprocessing` returned an eager tensor here -- TODO confirm.\n",
    "# A leading batch axis is added to the resulting array.\n",
    "np_processed_images = np.expand_dims(processed_image.numpy(), axis=0)\n",
    "# Load the model\n",
    "loaded_model = tf.saved_model.load(model_dir)\n",
    "# Keep only the default serving signature; `loaded_model` is rebound to that callable.\n",
    "loaded_model = loaded_model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n",
    "# Run the signature and take its 'output' entry as a NumPy array.\n",
    "out = loaded_model(tf.constant(np_processed_images))['output'].numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "打印标签信息："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm_book.data.imagenet.classification import ImageNet1kAttr\n",
    "\n",
    "imagenet1k_attr = ImageNet1kAttr()\n",
    "# Output indices of the first batch entry, ordered from most to least probable.\n",
    "sorted_inds = np_probabilities[0].argsort()[::-1]\n",
    "topk = 5\n",
    "print(f\"真实标签：{imagenet1k_attr.classes_long[label_id]}\")\n",
    "for sorted_ind in sorted_inds[:topk]:\n",
    "    # Output index i is shifted by one relative to the label table; presumably\n",
    "    # index 0 of the 1001-way output is an extra background class -- TODO confirm.\n",
    "    class_id = sorted_ind - 1\n",
    "    if class_id < 0:\n",
    "        # Without this guard, -1 would silently pick the last entry of the table.\n",
    "        label = \"background\"\n",
    "    else:\n",
    "        label = imagenet1k_attr.classes_long[class_id]\n",
    "    print(f\"{class_id}: {label.ljust(38)}\\t{np_probabilities[0, sorted_ind]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvmz",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
